Add RoPE embeddings#5481
Conversation
| self.cos_cached = nnx.Variable(jnp.cos(freqs_outer).astype(dtype)) | ||
| self.sin_cached = nnx.Variable(jnp.sin(freqs_outer).astype(dtype)) |
There was a problem hiding this comment.
In multiple JAX repositories I saw that sin and cos are constructed from segment_pos (absolute token positions):
- our gemma example
- bonsai and jax-llm-examples
- gemma4
The idea is that if the input sequence is not packed, but just padded, our implementation would mostly work as expected. In case of packed sequence where multiple sentences are inserted in the same sequence:
[<bos>, 1, 2, 3, <eos>, <bos>, 4, 5, <eos>, <pad>, <pad>]
then the input positions (segment_pos) would be:
[0, 1, 2, 3, 4, 0, 1, 2, 3, 0, 0]
so, RoPE may be computed differently.
On the other hand, current MHA.__call__ does not accept any positions arg, so we can't pass it to RoPE...
In PyTorch, basic implementation does something similar to your implementation, but cos and sin cached from the input x.
There was a problem hiding this comment.
My implementation was based on the one in equinox, which doesn't pass explicit positions.
- We could give
MHAa new optionalpositionsargument which we thread through to the attention_fn. But this could break users' custom attention_fn implementations if they weren't expecting the argument. - We could make a MHA subclass with a call method that accepts a
positionsargument. Say,PackedMHA? This is a little more ugly, but wouldn't be a breaking change.
There was a problem hiding this comment.
Can we do the following?
class MHA:
def __call__(self, ..., input_positions: Array | None = None):
...
attn_kwargs = {}
if input_positions is not None:
attn_kwargs["input_positions"] = input_positions
x = self.attention_fn(
query,
key,
value,
mask=mask,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
broadcast_dropout=self.broadcast_dropout,
deterministic=deterministic,
dtype=self.dtype,
precision=self.precision,
module=self if sow_weights else None,
is_causal=is_causal,
**attn_kwargs,
)
def dot_product_attention_with_rope(..., rope, input_positions: Array | None = None, **kwargs)
# handle properly input_positions is None
# input_positions: (B, S)
apply = jax.vmap(rope, in_axes=(-2, -1), out_axes=(-2, -1))
query = apply(query, input_positions)
key = apply(key, input_positions)
...There was a problem hiding this comment.
I like that! So users that aren't already calling MHA with input positions won't have their custom attention_fn break!
There was a problem hiding this comment.
Should we still keep the cached sin and cos vectors for use when input_positions=None? Might be slightly faster for the non-packed case, as we wouldn't need to rebuild them. But the interface could be nicer if we just rebuilt them every time, so that the RoPE constructor wouldn't need max_seq_len and embedding_size arguments (dynamically getting them from the input x). What do you think?
There was a problem hiding this comment.
We can cache them after the first call ?
There was a problem hiding this comment.
So we'd check if the input_positions is the same as the cached one at each call. If so, we'd use the cache (populated on the first call) and otherwise, we'd generate it dynamically?
To work in tree mode, we'd need to create the Variables for the cache at initialization. But then we could write to these variables on the first __call__.
There was a problem hiding this comment.
If we're doing cross attention, things get trickier. The sequence lengths for the keys might not be the same as the lengths for the values. Caching a single value wouldn't account for this. Caching everything would break if the packing of each batch might be different.
There was a problem hiding this comment.
Good point, actually, not sure if seen RoPE for cross-attention, but it makes sense that we may need to have q_positions and kv_positions or even k_positions, v_positions. But it can become rather cumbersome the API finally
There was a problem hiding this comment.
I ended up at a compromise API. If the user wants caching, they can specify the max size during initialization, like I had before. Otherwise, we don't cache, and compute the rotation matrices on the fly.
This also means we don't have to use QDD for the cache, which might become deprecated if we end up switching to hijax variables and QDD is removed. One less thing to worry about down the road.
66c99da to
319550d
Compare
This PR adds support for RoPE. Specifically, a new function
dot_product_attention_with_ropecan be used as theattention_fnargument fornnx.MultiHeadAttention.